package com.comtom.soft.thrift.client;

import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Map;

import org.apache.commons.pool.BasePoolableObjectFactory;
import org.apache.thrift.async.TAsyncClient;
import org.apache.thrift.async.TAsyncClientFactory;
import org.apache.thrift.async.TAsyncClientManager;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TMultiplexedProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TNonblockingSocket;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.comtom.soft.thrift.exception.ThriftException;

/**
 * Connection-pool object factory for asynchronous Thrift clients (thrift-client for Spring).
 *
 * <p>Each pooled object is an {@link AsyncClientAdapter} wrapping a {@link TAsyncClient} and the
 * {@link TNonblockingTransport} it talks over. The server address is obtained from a
 * {@link ThriftServerProvider} every time a new client is made, and the concrete client is built
 * reflectively from the configured {@link TAsyncClientFactory} class.
 */
public class AyncThriftClientPoolFactory extends BasePoolableObjectFactory<AsyncClientAdapter> {

	private final Logger logger = LoggerFactory.getLogger(getClass());

	private final ThriftServerProvider serverAddressProvider;
	private final Class<TAsyncClientFactory<TAsyncClient>> factoryClass;
	private final PoolOperationCallBack callback;

	/**
	 * The {@link TAsyncClientManager} created for each live pooled client.
	 *
	 * <p>A Thrift client manager runs its own selector thread until {@code stop()} is called. The
	 * manager is therefore remembered per client and stopped in {@link #destroyObject}. Keyed by
	 * identity because nothing here guarantees {@link AsyncClientAdapter} has identity-safe
	 * equals/hashCode. Synchronized because the pool may create and destroy objects from different
	 * threads.
	 */
	private final Map<AsyncClientAdapter, TAsyncClientManager> clientManagers =
			Collections.synchronizedMap(new IdentityHashMap<AsyncClientAdapter, TAsyncClientManager>());

	/**
	 * Creates a factory without a lifecycle callback.
	 *
	 * @param addressProvider source of server addresses for new clients
	 * @param factoryClass Thrift async client factory class; must have a public
	 *        {@code (TAsyncClientManager, TProtocolFactory)} constructor
	 */
	protected AyncThriftClientPoolFactory(ThriftServerProvider addressProvider, Class<TAsyncClientFactory<TAsyncClient>> factoryClass) throws Exception {
		this(addressProvider, factoryClass, null);
	}

	/**
	 * Creates a factory.
	 *
	 * @param addressProvider source of server addresses for new clients
	 * @param factoryClass Thrift async client factory class; must have a public
	 *        {@code (TAsyncClientManager, TProtocolFactory)} constructor
	 * @param callback optional hook notified after creation and before destruction; may be {@code null}
	 */
	public AyncThriftClientPoolFactory(ThriftServerProvider addressProvider, Class<TAsyncClientFactory<TAsyncClient>> factoryClass,
			PoolOperationCallBack callback) throws Exception {
		this.serverAddressProvider = addressProvider;
		this.factoryClass = factoryClass;
		this.callback = callback;
	}

	/** Lifecycle hooks for pooled clients. Exceptions thrown by either method are logged and ignored. */
	public static interface PoolOperationCallBack {
		// Invoked before a client is destroyed.
		void destroy(AsyncClientAdapter client);

		// Invoked after a client has been created successfully.
		void make(AsyncClientAdapter client);
	}

	/**
	 * Runs the destroy callback (best effort), then closes the client's transport and stops the
	 * client manager that was created for it in {@link #makeObject()}.
	 */
	@Override
	public void destroyObject(AsyncClientAdapter client) throws Exception {
		if (callback != null) {
			try {
				callback.destroy(client);
			} catch (Exception e) {
				// Best effort: a failing hook must not prevent the connection from being released.
				logger.warn("destroyObject callback failed", e);
			}
		}
		logger.info("destroyObject:{}", client);
		TAsyncClientManager clientManager = clientManagers.remove(client);
		try {
			if (client.getTransport() != null) {
				client.getTransport().close();
			}
		} finally {
			// Stop the manager even if closing the transport fails; otherwise its thread leaks.
			if (clientManager != null) {
				clientManager.stop();
			}
		}
	}

	@Override
	public void activateObject(AsyncClientAdapter client) throws Exception {
	}

	@Override
	public void passivateObject(AsyncClientAdapter client) throws Exception {
	}

	/** Every pooled client is considered valid; no health check is performed. */
	@Override
	public boolean validateObject(AsyncClientAdapter client) {
		return true;
	}

	/**
	 * Creates a new async client for an address chosen by the server provider.
	 *
	 * <p>If any step after opening the socket fails, the socket and the client manager are released
	 * before the exception propagates.
	 *
	 * @return the new client wrapped in an {@link AsyncClientAdapter}
	 * @throws ThriftException if the provider returns no address
	 * @throws Exception if the socket, client manager, or reflective client factory cannot be created
	 */
	@Override
	public AsyncClientAdapter makeObject() throws Exception {
		InetSocketAddress address = serverAddressProvider.selector();
		if (address == null) {
			logger.warn("can not get client from zookeeper.");
			throw new ThriftException("can not get client from zookeeper.");
		}
		TNonblockingTransport transport = new TNonblockingSocket(address.getHostName(), address.getPort());
		TAsyncClientManager clientManager = null;
		AsyncClientAdapter adapter = null;
		boolean created = false;
		try {
			clientManager = new TAsyncClientManager();
			TAsyncClient client = factoryClass
					.getConstructor(TAsyncClientManager.class, TProtocolFactory.class)
					.newInstance(clientManager, createProtocolFactory(address))
					.getAsyncClient(transport);

			adapter = new AsyncClientAdapter();
			adapter.setAsyncClient(client);
			adapter.setTransport(transport);
			created = true;
		} finally {
			if (!created) {
				// Do not leak the open socket or the manager's selector thread on a failed creation.
				transport.close();
				if (clientManager != null) {
					clientManager.stop();
				}
			}
		}
		clientManagers.put(adapter, clientManager);

		if (callback != null) {
			try {
				callback.make(adapter);
			} catch (Exception e) {
				// Best effort: the client itself was created successfully.
				logger.warn("makeObject callback failed", e);
			}
		}
		return adapter;
	}

	/**
	 * Chooses the wire protocol for a client.
	 *
	 * <p>Uses the compact protocol, wrapped in a {@link TMultiplexedProtocol} named after the service
	 * when the address is a {@link ServiceSocketAddress}.
	 */
	private static TProtocolFactory createProtocolFactory(InetSocketAddress address) {
		if (address instanceof ServiceSocketAddress) {
			final ServiceSocketAddress serviceAddress = (ServiceSocketAddress) address;
			return new TProtocolFactory() {
				private static final long serialVersionUID = 1L;

				@Override
				public TProtocol getProtocol(TTransport trans) {
					return new TMultiplexedProtocol(new TCompactProtocol(trans), serviceAddress.getServiceName());
				}
			};
		}
		return new TCompactProtocol.Factory();
	}

}
